-
Notifications
You must be signed in to change notification settings - Fork 90
fix: jax reducers returning incorrect output values or lengths #3464
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
I'm currently skipping the argmin argmax tests because of #3463 but that should change. |
awkward/src/awkward/_connect/jax/reducers.py Lines 296 to 299 in c93da2f
So we can't take the product of an array with negative numbers? The logarithm will just NaN the output |
I may have fixed that in 9874810. I still haven't made |
@ikrommyd - impressive! There are only 5 tests that fails: 3 for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for looking into it!
@@ -261,7 +328,7 @@ def apply( | |||
if array.dtype.kind == "M": | |||
raise TypeError(f"cannot compute the sum (ak.sum) of {array.dtype!r}") | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ikrommyd - if we want to allow sum of boolean, I think, we should view the data as integers here:
if array.dtype == np.bool_:
data = array.data.astype(jax.numpy.int32)
else:
data = array.data
So I made ci pass but there are a few things that are to be done for sure
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ikrommyd - just some minor comments. I agree, JAX backend needs to be thoroughly tested. Perhaps, our fellow could take over? @pfackeldey - when do we discuss his project? Thanks.
src/awkward/_connect/jax/reducers.py
Outdated
@@ -68,7 +131,7 @@ def segment_argmin(data, segment_ids): | |||
class ArgMin(JAXReducer): | |||
name: Final = "argmin" | |||
needs_position: Final = True | |||
preferred_dtype: Final = np.int64 | |||
preferred_dtype: Final = np.float64 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
argmin
returns the index (i.e. position) of the minimum value. Indices are always integers, not floating-point numbers.
src/awkward/_connect/jax/reducers.py
Outdated
@@ -125,7 +191,7 @@ def segment_argmax(data, segment_ids): | |||
class ArgMax(JAXReducer): | |||
name: Final = "argmax" | |||
needs_position: Final = True | |||
preferred_dtype: Final = np.int64 | |||
preferred_dtype: Final = np.float64 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the same argument here - indices are always integers, not floating-point numbers.
Needs more work and I'd appreciate any help @ianna @pfackeldey.
I'm adding a test that tests all the reducers via parametrization.
Definitely needs #3457
This PR should fix #3456, #3462, #3463 and #3465